{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n8VjWQYXRIvE"
      },
      "source": [
        "```\n",
        "- Copyright 2023 DeepMind Technologies Limited\n",
        "- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0\n",
        "- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode\n",
        "- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.\n",
        "- This is not an official Google product\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LVKsiNuyfoNZ"
      },
      "source": [
        "# Admissible set\n",
        "\n",
        "This notebook contains\n",
        "\n",
        "1. the *skeletons* we used for FunSearch to discover large admissible sets,\n",
        "2. the *functions* discovered by FunSearch that construct large admissible sets.\n",
        "\n",
        "## Skeleton\n",
        "\n",
        "This skeleton searches for unrestricted constant-weight admissible sets. The commented-out decorators are just a way to indicate the main entry point of the program (`@funsearch.run`) and the function that *FunSearch* should evolve (`@funsearch.evolve`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X-mpL0261pBp"
      },
      "outputs": [],
      "source": [
        "\"\"\"Finds large admissible sets.\"\"\"\n",
        "import itertools\n",
        "import math\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def block_children(scores: np.ndarray,\n",
        "                   admissible_set: np.ndarray,\n",
        "                   new_element: np.ndarray) -\u003e None:\n",
        "  \"\"\"Modifies `scores` to -inf for elements blocked by `new_element`.\"\"\"\n",
        "  n = admissible_set.shape[-1]\n",
        "  powers = 3 ** np.arange(n - 1, -1, -1)\n",
        "\n",
        "  invalid_vals_raw = {\n",
        "      (0, 0): (0,),\n",
        "      (0, 1): (1,),\n",
        "      (0, 2): (2,),\n",
        "      (1, 0): (1,),\n",
        "      (1, 1): (0, 1, 2),\n",
        "      (1, 2): (1, 2),\n",
        "      (2, 0): (2,),\n",
        "      (2, 1): (1, 2),\n",
        "      (2, 2): (0, 1, 2),\n",
        "  }\n",
        "  invalid_vals = [[np.array(invalid_vals_raw[(i, j)], dtype=np.int32)\n",
        "                   for j in range(3)] for i in range(3)]\n",
        "\n",
        "  # Block 2**w elements with the same support as `new_element`.\n",
        "  w = np.count_nonzero(new_element)\n",
        "  all_12s = np.array(list(itertools.product((1, 2), repeat=w)), dtype=np.int32)\n",
        "  blocking = np.einsum('aw,w-\u003ea', all_12s, powers[new_element != 0])\n",
        "  scores[blocking] = -np.inf\n",
        "\n",
        "  # Block elements disallowed by a pair of an extant point and `new_element`.\n",
        "  for extant_element in admissible_set:\n",
        "    blocking = np.zeros(shape=(1,), dtype=np.int32)\n",
        "    for e1, e2, power in zip(extant_element, new_element, powers):\n",
        "      blocking = (blocking[:, None] + (invalid_vals[e1][e2] * power)[None, :]\n",
        "                  ).ravel()\n",
        "    scores[blocking] = -np.inf\n",
        "\n",
        "\n",
        "def solve(n: int, w: int) -\u003e np.ndarray:\n",
        "  \"\"\"Generates a constant-weight admissible set I(n, w).\"\"\"\n",
        "  children = np.array(list(itertools.product((0, 1, 2), repeat=n)),\n",
        "                      dtype=np.int32)\n",
        "\n",
        "  scores = -np.inf * np.ones((3 ** n,), dtype=np.float32)\n",
        "  for child_index, child in enumerate(children):\n",
        "    if sum(child == 0) == n - w:\n",
        "      scores[child_index] = priority(np.array(child), n, w)\n",
        "\n",
        "  max_admissible_set = np.empty((0, n), dtype=np.int32)\n",
        "  while np.any(scores != -np.inf):\n",
        "    # Find element with largest score.\n",
        "    max_index = np.argmax(scores)\n",
        "    child = children[max_index]\n",
        "    block_children(scores, max_admissible_set, child)\n",
        "    max_admissible_set = np.concatenate([max_admissible_set, child[None]],\n",
        "                                        axis=0)\n",
        "\n",
        "  return max_admissible_set\n",
        "\n",
        "\n",
        "# @funsearch.run\n",
        "def evaluate(n: int, w: int) -\u003e int:\n",
        "  \"\"\"Returns the size of the constructed admissible set.\"\"\"\n",
        "  return len(solve(n, w))\n",
        "\n",
        "\n",
        "# @funsearch.evolve\n",
        "def priority(el: tuple[int, ...], n: int, w: int) -\u003e float:\n",
        "  \"\"\"Returns the priority with which we want to add `el` to the set.\"\"\"\n",
        "  return 0.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "88hCaeRa1yfs"
      },
      "source": [
        "By executing the skeleton with the trivial `priority` function in place we can for example check that in $n = 12$ dimensions and for weight $w = 7$ it only constructs an admissible set of size $548 \u003c 792 = {12 \\choose 7}$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t_YXZ2ld16gd",
        "outputId": "458a2750-e51c-44bf-f05a-e9c03c5de971"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "548\n"
          ]
        }
      ],
      "source": [
        "print(evaluate(n=12, w=7))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mB3NGXyT2ctY"
      },
      "source": [
        "## Discovered function that builds a full-size $I(12, 7)$ admissible set\n",
        "\n",
        "This function discovered by FunSearch results in a full-sized admissible set $I(12, 7)$, i.e. of size ${12 \\choose 7} = 792$:\n",
        "\n",
        "*Note*: Executing this cell takes around 1 minute."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n4DOhqEM2iuA"
      },
      "outputs": [],
      "source": [
        "def priority(el: tuple[int, ...], n: int, w: int) -\u003e float:\n",
        "  score = 0.0\n",
        "  for i in range(n):\n",
        "    if el[i] == 1:\n",
        "      score -= 0.9 ** (i % 4)\n",
        "    if el[i] == 2:\n",
        "      score -= 0.98 ** (30 - (i % 4))\n",
        "    if el[i] == 1 and el[i - 4] == 1:\n",
        "      score -= 0.98 ** (30 - (i % 4))\n",
        "    if el[i] == 2 and el[i - 4] != 0:\n",
        "      score -= 0.98 ** (30 - (i % 4))\n",
        "    if el[i] == 2 and el[i - 4] == 1 and el[i - 8] == 2:\n",
        "      score -= 0.98 ** (30 - (i % 4))\n",
        "      score -= 6.3\n",
        "    if el[i] == 2 and el[i - 4] == 2 and el[i - 8] == 1:\n",
        "      score -= 0.98 ** (30 - (i % 4))\n",
        "    if el[i] == 2 and el[i - 4] == 1 and el[i - 8] == 1:\n",
        "      score -= 6.3\n",
        "    if el[i] == 2 and el[i - 4] == 0 and el[i - 8] == 2:\n",
        "      score -= 6.3\n",
        "    if el[i] == 1 and el[i - 4] == 1 and el[i - 8] == 0:\n",
        "      score -= 2.2\n",
        "  return score\n",
        "\n",
        "\n",
        "admissible_12_7 = solve(12, 7)\n",
        "assert admissible_12_7.shape == (math.comb(12, 7), 12)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ylCu3kQTWGC-"
      },
      "source": [
        "This admissible set already implies an improved lower bound on the cap set capacity compared to the previous state-of-the-art of `2.218021`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "na5pU5DWWIKs",
        "outputId": "84f9cd21-7ed2-4524-c085-8f344ba55a26"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "2.2184431522494577"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "def compute_capacity_bound(n: int, w: int, size: int, m: int) -\u003e float:\n",
        "  \"\"\"Returns the lower bound on the cap set capacity.\n",
        "\n",
        "  We use discovered admissible sets A(n, w) to construct large cap sets,\n",
        "  following a recipe analogous to [Edel, 2004] and [Tyrrell, 2022]:\n",
        "  1. Start with the extendable collection E1 = (A0, A1, A2) of three\n",
        "     6-dimensional cap sets of respective sizes (a0, a1, a2) = (12, 112, 112).\n",
        "  2. Apply a recursively admissible set I(m, m - 1) to E1, which results in a\n",
        "     new extendable collection E2 = (B0, B1, B2) of three 6*m-dimensional cap\n",
        "     sets of sizes (b0, b1, b2) = (a0 * m * a1 ** (m - 1), a1 ** m, a1 ** m).\n",
        "  3. Apply the admissible set A(n, w) of size `size` to E2, which results in a\n",
        "     6*m*n-dimensional cap set C of size `size * (b0 ** (n - w)) * (b1 ** w)`.\n",
        "\n",
        "  Args:\n",
        "    n: Dimensionality of the discovered admissible set A(n, w).\n",
        "    w: The weight of the vectors in the discovered admissible set A(n, w).\n",
        "    size: The size |A(n, w)| of the discovered admissible set.\n",
        "    m: Dimensionality of the recursively admissible set I(m, m - 1) to use.\n",
        "  \"\"\"\n",
        "  a0, a1, _ = (12, 112, 112)\n",
        "  b0 = m * a0 * (a1 ** (m - 1))\n",
        "  b1 = a1 ** m\n",
        "  log_cap_set_size = np.log(size) + (n - w) * np.log(b0) + w * np.log(b1)\n",
        "  log_capacity = log_cap_set_size / (6 * m * n)\n",
        "  return np.exp(log_capacity)\n",
        "\n",
        "\n",
        "compute_capacity_bound(12, 7, len(admissible_12_7), m=7)"
      ]
    },
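    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch of what `compute_capacity_bound` evaluates (this only restates its docstring; the symbols $a_0$, $a_1$, $b_0$, $b_1$ and $C$ mirror the names used there): the construction yields a cap set $C$ in dimension $6mn$, and the bound is the $6mn$-th root of its size:\n",
        "\n",
        "$$b_0 = m \\, a_0 \\, a_1^{m-1}, \\qquad b_1 = a_1^{m}, \\qquad (a_0, a_1) = (12, 112),$$\n",
        "\n",
        "$$|C| = |A(n, w)| \\cdot b_0^{\\,n - w} \\cdot b_1^{\\,w}, \\qquad \\text{capacity} \\ge |C|^{1 / (6 m n)}.$$\n",
        "\n",
        "The function evaluates this in log space, presumably to avoid forming the very large integer $|C|$ directly."
      ]
    },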
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1B4KDG454YrI"
      },
      "source": [
        "Furthermore, we can notice that this discovered *program* for $I(12, 7)$ treats the `n` coordinates in a highly-symmetric way: for `n = 12` the four triples of coordinates `{0, 4, 8}`, `{1, 5, 9}`, `{2, 6, 10}`, and `{3, 7, 11}` are treated together. We can verify that the constructed admissible set is preserved under independent cyclic permutations of coordinates within each of these four triples:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qxvmmTDC5Hhj"
      },
      "outputs": [],
      "source": [
        "def get_cyclic_permutations(partition: list[list[int]]) -\u003e set[tuple[int, ...]]:\n",
        "  \"\"\"Returns all combinations of cyclic permutations within `partition`.\"\"\"\n",
        "  identity_permutation = list(range(sum(map(len, partition))))\n",
        "  permutations = set()\n",
        "  for cyclic_shifts in itertools.product(*[range(len(g)) for g in partition]):\n",
        "    permutation = list(identity_permutation)\n",
        "    for group, cyclic_shift in zip(partition, cyclic_shifts):\n",
        "      for i, x in enumerate(group):\n",
        "        permutation[x] = group[(i + cyclic_shift) % len(group)]\n",
        "    permutations.add(tuple(permutation))\n",
        "  return permutations\n",
        "\n",
        "\n",
        "# Obtain all independent cyclic permutations of coordinates within each of the\n",
        "# following four triples of coordinates. There are 3**4=81 such permutations.\n",
        "coordinate_triples = [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]\n",
        "permutations = get_cyclic_permutations(coordinate_triples)\n",
        "assert len(permutations) == 3 ** 4\n",
        "\n",
        "# Check that permuting coordinates in any of these 81 ways preserves the\n",
        "# admissible set as a set of vectors, i.e. up to the order of its rows.\n",
        "original_set = set(map(tuple, admissible_12_7))\n",
        "for permutation in permutations:\n",
        "  permuted_set = set(map(tuple, admissible_12_7[:, permutation]))\n",
        "  assert original_set == permuted_set"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-69mZ9LY1lVZ"
      },
      "source": [
        "This observation motivates directly searching for admissible sets that possess such a symmetry; we call such admissible sets *symmetric*. This is a more restricted but also a much smaller search space, which allows us to scale to much larger dimensions `n`. See our paper for more details.\n",
        "\n",
        "## Skeleton for *symmetric* admissible set\n",
        "\n",
        "The commented-out decorators are just a way to indicate the main entry point of the program (`@funsearch.run`) and the function that *FunSearch* should evolve (`@funsearch.evolve`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3zZ0fAe6flO_"
      },
      "outputs": [],
      "source": [
        "\"\"\"Finds large symmetric admissible sets.\"\"\"\n",
        "import itertools\n",
        "import math\n",
        "import numpy as np\n",
        "\n",
        "TRIPLES = [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 2), (0, 2, 1), (1, 1, 1), (2, 2, 2)]\n",
        "INT_TO_WEIGHT = [0, 1, 1, 2, 2, 3, 3]\n",
        "\n",
        "\n",
        "def expand_admissible_set(\n",
        "    pre_admissible_set: list[tuple[int, ...]]) -\u003e list[tuple[int, ...]]:\n",
        "  \"\"\"Expands a pre-admissible set into an admissible set.\"\"\"\n",
        "  num_groups = len(pre_admissible_set[0])\n",
        "  admissible_set = []\n",
        "  for row in pre_admissible_set:\n",
        "    rotations = [[] for _ in range(num_groups)]\n",
        "    for i in range(num_groups):\n",
        "      x, y, z = TRIPLES[row[i]]\n",
        "      rotations[i].append((x, y, z))\n",
        "      if not x == y == z:\n",
        "        rotations[i].append((z, x, y))\n",
        "        rotations[i].append((y, z, x))\n",
        "    product = list(itertools.product(*rotations))\n",
        "    concatenated = [sum(xs, ()) for xs in product]\n",
        "    admissible_set.extend(concatenated)\n",
        "  return admissible_set\n",
        "\n",
        "\n",
        "def get_surviving_children(extant_elements, new_element, valid_children):\n",
        "  \"\"\"Returns the indices of `valid_children` that remain valid after adding `new_element` to `extant_elements`.\"\"\"\n",
        "  bad_triples = set([\n",
        "      (0, 0, 0), (0, 1, 1), (0, 2, 2), (0, 3, 3), (0, 4, 4), (0, 5, 5),\n",
        "      (0, 6, 6), (1, 1, 1), (1, 1, 2), (1, 2, 2), (1, 2, 3), (1, 2, 4),\n",
        "      (1, 3, 3), (1, 4, 4), (1, 5, 5), (1, 6, 6), (2, 2, 2), (2, 3, 3),\n",
        "      (2, 4, 4), (2, 5, 5), (2, 6, 6), (3, 3, 3), (3, 3, 4), (3, 4, 4),\n",
        "      (3, 4, 5), (3, 4, 6), (3, 5, 5), (3, 6, 6), (4, 4, 4), (4, 5, 5),\n",
        "      (4, 6, 6), (5, 5, 5), (5, 5, 6), (5, 6, 6), (6, 6, 6)])\n",
        "\n",
        "  # Compute.\n",
        "  valid_indices = []\n",
        "  for index, child in enumerate(valid_children):\n",
        "    # Invalidate based on 2 elements from `new_element` and 1 element from a\n",
        "    # potential child.\n",
        "    if all(INT_TO_WEIGHT[x] \u003c= INT_TO_WEIGHT[y]\n",
        "           for x, y in zip(new_element, child)):\n",
        "      continue\n",
        "    # Invalidate based on 1 element from `new_element` and 2 elements from a\n",
        "    # potential child.\n",
        "    if all(INT_TO_WEIGHT[x] \u003e= INT_TO_WEIGHT[y]\n",
        "           for x, y in zip(new_element, child)):\n",
        "      continue\n",
        "    # Invalidate based on 1 element from `extant_elements`, 1 element from\n",
        "    # `new_element`, and 1 element from a potential child.\n",
        "    is_invalid = False\n",
        "    for extant_element in extant_elements:\n",
        "      if all(tuple(sorted((x, y, z))) in bad_triples\n",
        "             for x, y, z in zip(extant_element, new_element, child)):\n",
        "        is_invalid = True\n",
        "        break\n",
        "    if is_invalid:\n",
        "      continue\n",
        "\n",
        "    valid_indices.append(index)\n",
        "  return valid_indices\n",
        "\n",
        "\n",
        "def solve(n: int, w: int) -\u003e tuple[np.ndarray, np.ndarray]:\n",
        "  \"\"\"Generates a symmetric constant-weight admissible set I(n, w).\"\"\"\n",
        "  num_groups = n // 3\n",
        "  assert 3 * num_groups == n\n",
        "\n",
        "  # Compute the scores of all valid (weight w) children.\n",
        "  valid_children = []\n",
        "  for child in itertools.product(range(7), repeat=num_groups):\n",
        "    weight = sum(INT_TO_WEIGHT[x] for x in child)\n",
        "    if weight == w:\n",
        "      valid_children.append(np.array(child, dtype=np.int32))\n",
        "  valid_scores = np.array([\n",
        "      priority(sum([TRIPLES[x] for x in xs], ()), n, w)\n",
        "      for xs in valid_children])\n",
        "\n",
        "  # Greedy search guided by the scores.\n",
        "  pre_admissible_set = np.empty((0, num_groups), dtype=np.int32)\n",
        "  while valid_children:\n",
        "    max_index = np.argmax(valid_scores)\n",
        "    max_child = valid_children[max_index]\n",
        "    surviving_indices = get_surviving_children(pre_admissible_set, max_child,\n",
        "                                               valid_children)\n",
        "    valid_children = [valid_children[i] for i in surviving_indices]\n",
        "    valid_scores = valid_scores[surviving_indices]\n",
        "\n",
        "    pre_admissible_set = np.concatenate([pre_admissible_set, max_child[None]],\n",
        "                                        axis=0)\n",
        "\n",
        "  return pre_admissible_set, np.array(expand_admissible_set(pre_admissible_set))\n",
        "\n",
        "\n",
        "# @funsearch.run\n",
        "def evaluate(n: int, w: int) -\u003e int:\n",
        "  \"\"\"Returns the size of the expanded admissible set.\"\"\"\n",
        "  _, admissible_set = solve(n, w)\n",
        "  return len(admissible_set)\n",
        "\n",
        "\n",
        "# @funsearch.evolve\n",
        "def priority(el: tuple[int, ...], n: int, w: int) -\u003e float:\n",
        "  \"\"\"Returns the priority with which we want to add `el` to the set.\"\"\"\n",
        "  return 0.0"
      ]
    },
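    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a brief worked sketch of the expansion step in `expand_admissible_set` (the symbols $P$, $r$ and $k(r)$ are introduced here for illustration only): each entry of a pre-admissible row $r$ indexes into `TRIPLES`. A constant triple (indices 0, 5, 6) contributes a single block, whereas a non-constant triple (indices 1 to 4) contributes its 3 cyclic rotations. Writing $k(r)$ for the number of entries of $r$ that are 1, 2, 3, or 4, the expanded list built from a pre-admissible set $P$ therefore has length\n",
        "\n",
        "$$\\sum_{r \\in P} 3^{k(r)}.$$\n",
        "\n",
        "For example, the row $(1, 3, 5)$ has $k = 2$ and weight $1 + 2 + 3 = 6$, so it expands into $3 \\cdot 3 \\cdot 1 = 9$ vectors of length $9$."
      ]
    },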
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QY5jPdo-g1fT"
      },
      "source": [
        "By executing the skeleton with the trivial `priority` function in place we can for example check that in $n = 15$ dimensions and for weight $w = 10$ it only constructs an admissible set of size $1842 \u003c 3003 = {15 \\choose 10}$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1cLP6xvzfn1k",
        "outputId": "ddda94a8-fc23-4464-a3a7-faee4f37f203"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1842\n"
          ]
        }
      ],
      "source": [
        "print(evaluate(n=15, w=10))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I9-mf0aThXQl"
      },
      "source": [
        "## Discovered function that builds a full-size $I(15, 10)$ admissible set\n",
        "\n",
        "This function discovered by FunSearch results in a full-sized admissible set $I(15, 10)$, i.e. of size ${15 \\choose 10} = 3003$:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k-k8WyrohG8I",
        "outputId": "ba1f6409-73f7-446d-f980-7c332238e595"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.2194858716370582\n"
          ]
        }
      ],
      "source": [
        "def priority(el: tuple[int, ...], n: int, w: int) -\u003e float:\n",
        "  score = 0.0\n",
        "  for i in range(n):\n",
        "    if el[i] \u003c el[i - 1]:\n",
        "      score += 1\n",
        "    elif el[i] \u003c el[i - 2]:\n",
        "      score += 0.05\n",
        "    elif el[i] \u003c el[i - 3]:\n",
        "      score -= 0.05\n",
        "    elif el[i] \u003c el[i - 4]:\n",
        "      score += 0.01\n",
        "    elif el[i] \u003c el[i - 5]:\n",
        "      score -= 0.01\n",
        "    elif el[i] \u003c el[i - 6]:\n",
        "      score += 0.001\n",
        "    else:\n",
        "      score += 0.005\n",
        "\n",
        "  for i in range(n):\n",
        "    if el[i] == el[i - 1]:\n",
        "      score -= w\n",
        "    elif el[i] == 0 and i != n - 1 and el[i + 1] != 0:\n",
        "      score += w\n",
        "    if el[i] != el[i - 1]:\n",
        "      score += w\n",
        "\n",
        "  for i in range(n):\n",
        "    if el[i] \u003c el[i - 1]:\n",
        "      if el[i] == 0:\n",
        "        score -= w\n",
        "  return score\n",
        "\n",
        "\n",
        "pre_admissible_15_10, admissible_15_10 = solve(15, 10)\n",
        "assert admissible_15_10.shape == (math.comb(15, 10), 15)\n",
        "assert pre_admissible_15_10.shape == (101, 5)\n",
        "\n",
        "# Show the resulting lower bound on the cap set capacity.\n",
        "print(compute_capacity_bound(15, 10, len(admissible_15_10), m=5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jKTDz_ChifB8"
      },
      "source": [
        "## Discovered function that builds a size $43\\,596$ admissible set in $A(21, 15)$\n",
        "\n",
        "This admissible set implies an improved lower bound of $2.220041$ on the cap set capacity.\n",
        "\n",
        "*Note*: After uncommenting the invocation below, executing this cell can take ~15 minutes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uOz3hX0YiWKd"
      },
      "outputs": [],
      "source": [
        "def priority(el: tuple[int, ...], n: int, w: int) -\u003e float:\n",
        "  score = 0\n",
        "  coeff = 0\n",
        "  for pos, x in zip(range(n), el):\n",
        "    y = (el[(pos + 1) % n] - el[(pos)]) % n\n",
        "    z = (el[(pos + 2) % n] - el[(pos)]) % n\n",
        "    p = (el[(pos - 1) % n] + 1) % n\n",
        "\n",
        "    u = (el[(pos - 2) % n] + 1) % n\n",
        "    v = (el[(pos + 3) % n] + 1) % n\n",
        "\n",
        "    score += 3 * p * (p + coeff) * (p + w) + (p + coeff)**2 * (w + 1)\n",
        "    score += 2 * p * v * (p + w) + v * z * (-1 + w) - (p + coeff) * (-1 + w)\n",
        "    score += v * (u + w) + u + 3 * u * y * (1 + w) + u * z * (w - 1) - (p + coeff) * (w - 1)\n",
        "    score += (1 + w)**6 * 3 * coeff**2\n",
        "\n",
        "  return score\n",
        "\n",
        "\n",
        "# Uncomment to execute; note it can take ~15 minutes to run this.\n",
        "# pre_admissible_21_15, admissible_21_15 = solve(21, 15)\n",
        "# assert admissible_21_15.shape == (43_596, 21)\n",
        "# assert pre_admissible_21_15.shape == (308, 7)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y3h2fjO6WWXe",
        "outputId": "48a0e71b-1755-4ca5-8d0a-c82aa10f781e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.220040576961921\n"
          ]
        }
      ],
      "source": [
        "# Show the resulting lower bound on the cap set capacity.\n",
        "print(compute_capacity_bound(21, 15, 43_596, m=4))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_7KwzMAkdoJ"
      },
      "source": [
        "## Discovered function that builds a size $237\\,984$ admissible set in $A(24, 17)$\n",
        "\n",
        "This admissible set implies a new state-of-the-art lower bound of $2.220234$ on the cap set capacity.\n",
        "\n",
        "*Note*: Running `solve(24, 17)` takes ~7 hours. In practice we used a C++ reimplementation of the function `get_surviving_children` to speed things up. Note that we provide the result of running this code as standalone files `pre_admissible_set_for_n24_w17_size237984.txt` and `admissible_set_n24_w17_size237984.txt` in the same directory as this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KNBM5TBwkghi"
      },
      "outputs": [],
      "source": [
        "def priority(el: tuple[int, ...], n: int, w: int) -\u003e float:\n",
        "  result = 0.0\n",
        "  for i in range(n):\n",
        "    n_violations = 0\n",
        "\n",
        "    if el[i] \u003c el[i - 1]:\n",
        "      result += (el[i - 1] ** 0.5) * w ** 2 / (6 * 6)\n",
        "      n_violations += 1\n",
        "\n",
        "    if el[i] \u003c el[i - 2]:\n",
        "      result += el[i - 2] ** 0.5\n",
        "      n_violations += 1\n",
        "\n",
        "    if el[i - 1] != 0:\n",
        "      result -= (el[i] - el[i - 1]) * w ** 2 / (6 * 3)\n",
        "      n_violations += 2\n",
        "\n",
        "    if el[i - 2] != 0:\n",
        "      result -= (el[i] - el[i - 2]) * w ** 2 / (6 * 6) * (0.95 ** n_violations)\n",
        "      n_violations += 1\n",
        "\n",
        "    result -= (0.02 ** el[i]) * (el[i] - el[i - 8])\n",
        "\n",
        "  return result\n",
        "\n",
        "\n",
        "# Executing this would take ~7 hours.\n",
        "# pre_admissible_24_17, admissible_24_17 = solve(24, 17)\n",
        "# assert admissible_24_17.shape == (237_984, 24)\n",
        "# assert pre_admissible_24_17.shape == (736, 8)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "txdmW8feWXKg",
        "outputId": "3a1ee385-f4a8-4d05-a1e2-1b2778fd2665"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.2202336538665377\n"
          ]
        }
      ],
      "source": [
        "# Show the resulting lower bound on the cap set capacity.\n",
        "print(compute_capacity_bound(24, 17, 237_984, m=4))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
